Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend, MLIR] Support indexing of the dynamically shaped arrays #411

Merged
merged 14 commits into from
Jan 8, 2024

Conversation

sergei-mironov
Copy link
Contributor

@sergei-mironov sergei-mironov commented Dec 20, 2023

In this PR we enable indexing for tensors with dynamic shapes. This PR is intended to be merged after the #370

  • Indexing with constant
  • Indexing with variable
  • Modifications

[sc-47632]

The corresponding Jax PR suggests the fix to the upstream.

Copy link

codecov bot commented Dec 20, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (628a381) 99.56% compared to head (ae82b8b) 99.56%.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #411   +/-   ##
=======================================
  Coverage   99.56%   99.56%           
=======================================
  Files          43       43           
  Lines        7643     7646    +3     
  Branches      512      512           
=======================================
+ Hits         7610     7613    +3     
  Misses         17       17           
  Partials       16       16           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@sergei-mironov sergei-mironov self-assigned this Dec 20, 2023
@maliasadi maliasadi added the frontend Pull requests that update the frontend label Dec 21, 2023
@sergei-mironov
Copy link
Contributor Author

Hi @rmoyard , could you please review the MLIR-related parts of this PR?

@sergei-mironov sergei-mironov marked this pull request as ready for review December 22, 2023 11:59
@sergei-mironov sergei-mironov changed the title [Frontend] Support indexing of the dynamically shaped arrays [Frontend, MLIR] Support indexing of the dynamically shaped arrays Dec 22, 2023
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice patch 👍

frontend/catalyst/utils/jax_extras.py Outdated Show resolved Hide resolved
@josh146 josh146 mentioned this pull request Jan 4, 2024
@sergei-mironov sergei-mironov changed the base branch from dynshape-quantum-primitives to main January 8, 2024 11:40
@sergei-mironov
Copy link
Contributor Author

sergei-mironov commented Jan 8, 2024

@josh146 I meant this PR, when I asked about CodeFactor problems. The 'ComplexMethod' checks lack diagnostics so I suggest disabling these.

@sergei-mironov sergei-mironov merged commit f9c5a1f into main Jan 8, 2024
18 of 19 checks passed
@sergei-mironov sergei-mironov deleted the dynshape-indexing branch January 8, 2024 13:19
with Patcher((jax._src.interpreters.partial_eval, "get_aval", get_aval2)), ExitStack():
with Patcher(
(jax._src.interpreters.partial_eval, "get_aval", get_aval2),
(jax._src.lax.slicing, "gather_p", gather2_p),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dime10 @grwlf here we have an issue, the new rule does not define jvp and therefore it is not compatible with grad or jacobian transformations ad.defjvp(gather_p, _gather_jvp_rule, None) see the original gather_p

gather_p = standard_primitive(
    _gather_shape_rule, _gather_dtype_rule, 'gather',
    weak_type_rule=_argnum_weak_type(0))
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
pe.padding_rules[gather_p] = _gather_pad_rule

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rmoyard does this result in a user-facing bug?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you cannot use jax.grad inside qjit when a gather operation is created, for example slicing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so #305?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That issue is about Catalyst gradients, I think Romain is talking about JAX gradients (which are run in the frontend on the jaxpr, hence the primitives need gradient rules).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, got it!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly what David said, qjit with jax.grad and slicing

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qjit(jax.grad(f))(x)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking if we have a resolution on this particular comment thread :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@grwlf How is the upstream PR looking for this patch?
In the meantime, can we just attach the original gradient rule to the patched primitive?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants